
[wip] Fix loss reduction #909

Closed
justinvyu wants to merge 3 commits into NovaSky-AI:main from justinvyu:fix_loss_reduction

Conversation

@justinvyu
Contributor

No description provided.

Signed-off-by: Justin Yu <justinvyu@anyscale.com>
```python
self._micro_batches_accumulated += 1

# Track total tokens accumulated to compute average loss per token during optim_step
self._total_tokens_accumulated += micro_batch.attention_mask.sum().item()
```
Contributor Author


This should actually be `loss_mask`.

@justinvyu
Contributor Author

close in favor of #925

@justinvyu justinvyu closed this Jan 23, 2026
erictang000 pushed a commit that referenced this pull request Mar 31, 2026
…ization and fix `token_mean` reduction strategy (#1296)

**This is a breaking change for the default `token_mean` loss behavior,
as well as observed `policy_loss` metrics!** See the "Differences in
reported loss metric magnitudes" section below.

## Summary

* Change `reduce_loss()` to always return a simple masked sum (`(loss * mask).sum()`). Different reduction strategies are achieved by pre-scaling the advantages before they enter the loss function, which also aligns with how Tinker's API handles it.
* Scale the loss by the DP world size before calling `backward()`, so that the framework's default mean gradient all-reduce across data parallel workers effectively becomes a sum.
* Fix the `token_mean` loss reduction to take a mean across all tokens in the mini-batch rather than averaging across micro-batches. The old behavior remains available via the `token_mean_legacy` config.
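As a minimal sketch (the function and argument names here are assumptions, not the repo's actual signatures), the new reduction boils down to a plain masked sum:

```python
import torch

def reduce_loss(per_token_loss: torch.Tensor, loss_mask: torch.Tensor) -> torch.Tensor:
    # Always a simple masked sum. Any normalization (token mean,
    # sequence mean, Dr. GRPO constant) is folded into the advantages
    # before they enter the loss, rather than applied here.
    return (per_token_loss * loss_mask).sum()
```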

## Loss reduction strategies

* **Option 1: token_mean**
  * Average loss per token across the entire mini-batch.
  * **This is the fixed version: the denominator is the total token count across the full mini-batch, so the gradient is independent of how the mini-batch is split into micro-batches.**
* **Option 1b: token_mean_legacy**
  * Compute the token-mean loss within each micro-batch, then average across micro-batches.
  * This reproduces the `token_mean` behavior before this PR.
  * **The problem: if micro-batches have different token counts, the effective weighting differs from a true global token mean. It is also less usable, since changing the micro-batch size changes the loss and the training dynamics.**
  * Kept as a fallback in case of performance regressions; we should remove this down the line.
* **Option 2: sequence_mean**
  * Compute the per-token loss within each sequence, then average across sequences.
  * This behavior is unchanged; it is now implemented via advantage normalization instead.
* **Option 3: seq_mean_token_sum_norm**
  * Dr. GRPO style: normalize by a fixed constant to avoid any length-dependent weighting.
  * This behavior is unchanged; it is now implemented via advantage normalization instead.
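To make the difference between options 1 and 1b concrete, here is an illustrative sketch (function names are hypothetical) showing how the two reductions diverge when micro-batches have unequal token counts:

```python
import torch

def token_mean(per_token_loss: torch.Tensor, loss_mask: torch.Tensor) -> torch.Tensor:
    # Fixed `token_mean`: the denominator is the total token count of the
    # whole mini-batch, so the value is invariant to micro-batching.
    return (per_token_loss * loss_mask).sum() / loss_mask.sum()

def token_mean_legacy(micro_batches) -> torch.Tensor:
    # Legacy behavior: token-mean within each micro-batch, then an
    # unweighted mean across micro-batches, so micro-batches with fewer
    # tokens get disproportionate weight.
    per_mb = [(loss * mask).sum() / mask.sum() for loss, mask in micro_batches]
    return torch.stack(per_mb).mean()
```

With a 4-token micro-batch of per-token loss 1.0 and a 1-token micro-batch of loss 4.0, the fixed version gives (4 + 4) / 5 = 1.6, while the legacy version gives (1.0 + 4.0) / 2 = 2.5.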

## Mean all-reduce -> sum all-reduce

We need the loss to be summed across micro-batches and data parallel workers:
* DDP/FSDP defaults to a mean all-reduce for gradients across workers. This PR counteracts that by multiplying the loss by the DP world size, so that the loss is effectively summed across data parallel groups.
* Megatron does a similar mean reduction across micro-batches and workers, so we counteract it by multiplying by the number of micro-batches and the DP size to achieve a sum.
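The arithmetic behind the pre-scaling trick can be sketched with plain numbers (this is an illustration of the identity, not the actual distributed code):

```python
def mean_allreduce(grads):
    # DDP/FSDP default: per-worker gradients are averaged across the
    # data parallel group.
    return sum(grads) / len(grads)

def sum_via_prescale(grads):
    # Pre-scaling each worker's loss (and hence its gradient) by the
    # DP world size turns the framework's mean all-reduce into an
    # effective sum: (sum_i g_i * W) / W == sum_i g_i.
    world = len(grads)
    return sum(g * world for g in grads) / world
```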

## Differences in reported loss metric magnitudes

You will observe that the reported loss metric has a different magnitude compared to your older runs. This is because the old `token_mean` implementation was somewhere between a true token mean and a sequence mean due to per-micro-batch normalization (e.g., `micro_batch_size=1` was equivalent to a sequence mean).

The new `token_mean` is a proper minibatch token mean, while
`sequence_mean` weights every sequence equally regardless of length.
When comparing the loss produced by different reduction methods computed
on the same advantages, from a real run:

```
  token_mean (new):  0.322   — every token weighted equally across the mini-batch
  token_mean (old):  0.065   — mean of per-micro-batch token means, where micro_batch_size=4
  sequence_mean:     0.00098 — every sequence weighted equally
```
The old token_mean gave each micro-batch equal weight rather than each
token, so its scale depended on how advantages were distributed across
micro-batches. The new implementation is invariant to micro-batch size.

Note that `token_mean_legacy` still reports the old metrics, and the `sequence_mean` and `seq_mean_token_sum_norm` modes match exactly. See this
[comment](#1296 (comment))
for more details.

## Tinker compatibility

Here was the first attempt at fixing the loss reduction across
microbatches: #909

That approach tracked the total token count and performed one big normalization at `optim_step` to get an average per-token loss. Instead, we decided to align with Tinker's approach of simply summing the loss at the end and pushing any loss normalization into the user's advantage calculation.

The benefit is that users have full control over their loss reduction strategy, rather than having it happen inside our opaque `forward_backward` / `optim_step` implementation, which would require a configuration argument that diverges from Tinker's API. For example, we would need a config somewhere to determine how to average/sum the loss:

```python
client.forward_backward(...)
client.optim_step(..., loss_reduction="token_mean")  # no longer tinker compatible
```

The current PR aligns with [Tinker
semantics](https://tinker-docs.thinkingmachines.ai/losses#proximal-policy-optimization-ppo):

> Notice that for all objectives we sum the token-level losses over the
sequence length unlike some other loss implementations. If you would
like to explore different aggregation schemes, you can include that in
the advantage tensor computation.

Example for `loss_reduction="token_mean"`:
* Move the `1/num_minibatch_tokens` normalization into the advantage:
  `loss = sum(-advantage_i * ratio_i for i in range(num_minibatch_tokens)) / num_minibatch_tokens`
  `     = sum(-(advantage_i / num_minibatch_tokens) * ratio_i for i in range(num_minibatch_tokens))`
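The identity above can be checked numerically. In this sketch (the helper name is hypothetical), summing the loss over pre-scaled advantages reproduces the global token mean of the unscaled loss:

```python
import torch

def prescale_advantages_token_mean(advantages: torch.Tensor,
                                   loss_mask: torch.Tensor) -> torch.Tensor:
    # Fold the 1/num_minibatch_tokens normalization into the advantages,
    # so that a sum-only loss reproduces a global token mean.
    return advantages / loss_mask.sum()
```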

## Learning curve comparisons before/after the PR

### FSDP
([wandb](https://wandb.ai/justinvyu/dapo_aime/reports/FSDP-experiments--VmlldzoxNjIyNDExMQ))

<img width="817" height="664" alt="Screenshot 2026-03-20 at 3 29 10 PM"
src="https://github.com/user-attachments/assets/be3ac108-b5b6-4408-a270-884b7fd4bce3"
/>

### Megatron
([wandb](https://wandb.ai/justinvyu/dapo_aime/reports/Megatron-loss-reduction--VmlldzoxNjIyNDE2Ng))

#### 1.7B:
<img width="1340" height="637" alt="Screenshot 2026-03-20 at 3 29 40 PM"
src="https://github.com/user-attachments/assets/632d0c21-5b9c-48f3-aba9-d65ed00b13df"
/>

#### 30B lora:

master baseline:
<img width="1629" height="698" alt="Screenshot 2026-03-20 at 3 33 29 PM"
src="https://github.com/user-attachments/assets/54c07117-9e40-4551-91c0-78ae5f25e347"
/>

`token_mean_legacy` + fixed `token_mean`:

<img width="1149" height="660" alt="Screenshot 2026-03-24 at 11 16 56 AM"
src="https://github.com/user-attachments/assets/ef91b399-5f15-4db2-81ea-1e14c5932406"
/>



---------

Signed-off-by: Justin Yu <justinvyu@anyscale.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
